{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Non-linear classification models\n",
    "\n",
    "Run this notebook in Vertex Workbench. We start from the same features as\n",
    "in the [logistic regression notebook](bqml_logistic.ipynb) but use non-linear machine learning methods.\n",
    "\n",
    "The models in this notebook take longer to train than the linear models.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## xgboost\n",
    "\n",
    "xgboost is usually a very good model for structured data, which makes it a good next step after logistic regression.\n",
    "In BigQuery ML, we request it with `model_type='boosted_tree_classifier'`.\n",
    "Training this model will take ~10 minutes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing query with job ID: 9984af86-cf61-455d-8408-9a3f3b77b81e\n",
      "Query executing: 183.61s"
     ]
    }
   ],
   "source": [
    "%%bigquery\n",
    "CREATE OR REPLACE MODEL dsongcp.arr_delay_airports_xgboost\n",
    "OPTIONS(input_label_cols=['ontime'], \n",
    "        model_type='boosted_tree_classifier',\n",
    "        data_split_method='custom',\n",
    "        data_split_col='is_eval_day')\n",
    "AS\n",
    "\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  origin,\n",
    "  dest,\n",
    "  IF(is_train_day = 'True', False, True) AS is_eval_day\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND \n",
    "  f.DIVERTED = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now evaluate this model as before: predict on the held-out (non-training) days with a probability threshold of 0.7."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 3/3 [00:00<00:00, 1494.23query/s]                        \n",
      "Downloading: 100%|██████████| 1/1 [00:01<00:00,  1.03s/rows]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>correct_cancel</th>\n",
       "      <th>total_noncancel</th>\n",
       "      <th>correct_noncancel</th>\n",
       "      <th>total_cancel</th>\n",
       "      <th>rmse</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.838981</td>\n",
       "      <td>1304078</td>\n",
       "      <td>0.964965</td>\n",
       "      <td>283750</td>\n",
       "      <td>0.207227</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   correct_cancel  total_noncancel  correct_noncancel  total_cancel      rmse\n",
       "0        0.838981          1304078           0.964965        283750  0.207227"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "\n",
    "WITH predictions AS (\n",
    "SELECT \n",
    "  *\n",
    "FROM ML.PREDICT(MODEL dsongcp.arr_delay_airports_xgboost,\n",
    "                 (\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  origin,\n",
    "  dest,\n",
    "  IF(is_train_day = 'True', False, True) AS is_eval_day\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND \n",
    "  f.DIVERTED = False AND\n",
    "  t.is_train_day = 'False'\n",
    "                 ),\n",
    "                 STRUCT(0.7 AS threshold))),\n",
    "\n",
    "stats AS (\n",
    "SELECT \n",
    "  COUNTIF(ontime != 'ontime' AND ontime = predicted_ontime) AS correct_cancel\n",
    "  , COUNTIF(predicted_ontime = 'ontime') AS total_noncancel\n",
    "  , COUNTIF(ontime = 'ontime' AND ontime = predicted_ontime) AS correct_noncancel\n",
    "  , COUNTIF(ontime != 'ontime') AS total_cancel\n",
    "  , SQRT(SUM((IF(ontime = 'ontime', 1, 0) - p.prob) * (IF(ontime = 'ontime', 1, 0) - p.prob))/COUNT(*)) AS rmse\n",
    "FROM predictions, UNNEST(predicted_ontime_probs) p\n",
    "WHERE p.label = 'ontime'\n",
    ")\n",
    "\n",
    "SELECT\n",
    "   correct_cancel / total_cancel AS correct_cancel\n",
    "   , total_noncancel\n",
    "   , correct_noncancel / total_noncancel AS correct_noncancel\n",
    "   , total_cancel\n",
    "   , rmse\n",
    "FROM stats"
   ]
  },
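  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a cross-check on the hand-written metrics above, BigQuery ML's built-in `ML.EVALUATE` function can score the model on the same held-out days. The cell below is a minimal sketch rather than part of the original workflow: it assumes the `dsongcp.arr_delay_airports_xgboost` model trained above exists and reuses the 0.7 threshold. The precision and recall it reports are computed for whichever label BigQuery ML treats as the positive class, so they need not line up one-to-one with `correct_cancel` and `correct_noncancel`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bigquery\n",
    "-- Sketch: built-in classification metrics on the held-out days only.\n",
    "-- The label column (ontime) must be present so predictions can be scored.\n",
    "SELECT\n",
    "  *\n",
    "FROM ML.EVALUATE(MODEL dsongcp.arr_delay_airports_xgboost,\n",
    "                 (\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  origin,\n",
    "  dest\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND \n",
    "  f.DIVERTED = False AND\n",
    "  t.is_train_day = 'False'\n",
    "                 ),\n",
    "                 STRUCT(0.7 AS threshold))"
   ]
  },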
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 1/1 [00:00<00:00, 427.95query/s]                          \n",
      "Downloading: 100%|██████████| 1/1 [00:01<00:00,  1.18s/rows]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>predicted_ontime</th>\n",
       "      <th>predicted_ontime_probs</th>\n",
       "      <th>dep_delay</th>\n",
       "      <th>taxi_out</th>\n",
       "      <th>distance</th>\n",
       "      <th>origin</th>\n",
       "      <th>dest</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ontime</td>\n",
       "      <td>[{'label': 'ontime', 'prob': 0.868686914443969...</td>\n",
       "      <td>12.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>802</td>\n",
       "      <td>DFW</td>\n",
       "      <td>ORD</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  predicted_ontime                             predicted_ontime_probs  \\\n",
       "0           ontime  [{'label': 'ontime', 'prob': 0.868686914443969...   \n",
       "\n",
       "   dep_delay  taxi_out  distance origin dest  \n",
       "0       12.0      14.0       802    DFW  ORD  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "SELECT * FROM ML.PREDICT(MODEL dsongcp.arr_delay_airports_xgboost,\n",
    "                        (\n",
    "SELECT 12.0 AS dep_delay, 14.0 AS taxi_out, 802 AS distance, 'DFW' AS origin, 'ORD' AS dest\n",
    "                        ))"
   ]
  },
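  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `predicted_ontime_probs` column in the result above is an array with one entry per label, which is awkward to read in a dataframe. The following sketch, which assumes the same model and the same hypothetical DFW to ORD flight, unnests that array and keeps only the probability of the `'ontime'` label, in the same way the evaluation query does when it computes the RMSE."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bigquery\n",
    "-- Sketch: return the predicted label and P(ontime) for a single flight.\n",
    "WITH prediction AS (\n",
    "SELECT * FROM ML.PREDICT(MODEL dsongcp.arr_delay_airports_xgboost,\n",
    "                        (\n",
    "SELECT 12.0 AS dep_delay, 14.0 AS taxi_out, 802 AS distance, 'DFW' AS origin, 'ORD' AS dest\n",
    "                        ))\n",
    ")\n",
    "\n",
    "SELECT\n",
    "  predicted_ontime,\n",
    "  p.prob AS prob_ontime  -- probability attached to the 'ontime' label\n",
    "FROM prediction, UNNEST(predicted_ontime_probs) p\n",
    "WHERE p.label = 'ontime'"
   ]
  },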
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameter tuning (optional)\n",
    "\n",
    "Let's tune two things: the maximum tree depth (`max_tree_depth`, default=6) and L2 regularization (`l2_reg`, default=1.0).\n",
    "\n",
    "**This section will take ~60 minutes. You can skip it.**\n",
    "\n",
    "Note that `is_eval_day` is now a string column with 3 possible values: `'TRAIN'`, `'EVAL'`, and `'TEST'`."
   ]
  },
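  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before committing to a long tuning run, it can help to preview how the three-way split divides the data. The cell below is a sketch under the assumption that the split expression matches the one in the tuning query in this section: about 80% of training-day flights go to `'TRAIN'`, the remainder to `'EVAL'`, and all other days to `'TEST'`. Because `RAND()` is re-drawn on every run, the `'TRAIN'` and `'EVAL'` counts shown here will differ slightly from the rows actually used during tuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bigquery\n",
    "-- Sketch: count the flights that land in each split.\n",
    "SELECT\n",
    "  is_eval_day,\n",
    "  COUNT(*) AS num_flights\n",
    "FROM (\n",
    "  SELECT\n",
    "    IF(is_train_day = 'True', \n",
    "       IF(RAND() < 0.8, 'TRAIN', 'EVAL'), \n",
    "       'TEST') AS is_eval_day\n",
    "  FROM dsongcp.flights_tzcorr f\n",
    "  JOIN dsongcp.trainday t\n",
    "  ON f.FL_DATE = t.FL_DATE\n",
    "  WHERE\n",
    "    f.CANCELLED = False AND \n",
    "    f.DIVERTED = False\n",
    ")\n",
    "GROUP BY is_eval_day\n",
    "ORDER BY is_eval_day"
   ]
  },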
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing query with job ID: bea5283d-efdc-444c-8fee-3ef81f435d80\n",
      "Query executing: 3554.37s"
     ]
    }
   ],
   "source": [
    "%%bigquery\n",
    "CREATE OR REPLACE MODEL dsongcp.arr_delay_airports_xgh\n",
    "OPTIONS(input_label_cols=['ontime'], \n",
    "        model_type='boosted_tree_classifier',\n",
    "        num_trials=5, l2_reg=hparam_range(0.5, 3.0), max_tree_depth=hparam_range(2, 10),\n",
    "        data_split_method='custom',\n",
    "        data_split_col='is_eval_day')\n",
    "AS\n",
    "\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  origin,\n",
    "  dest,\n",
    "  IF(is_train_day = 'True', \n",
    "     IF(RAND() < 0.8, 'TRAIN', 'EVAL'), \n",
    "     'TEST') AS is_eval_day\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND \n",
    "  f.DIVERTED = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 2/2 [00:00<00:00, 911.51query/s]                         \n",
      "Downloading: 100%|██████████| 3/3 [00:01<00:00,  2.95rows/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>l2_reg</th>\n",
       "      <th>max_tree_depth</th>\n",
       "      <th>eval_loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2.536659</td>\n",
       "      <td>10</td>\n",
       "      <td>0.155262</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2.113224</td>\n",
       "      <td>10</td>\n",
       "      <td>0.155313</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.887189</td>\n",
       "      <td>10</td>\n",
       "      <td>0.155314</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     l2_reg  max_tree_depth  eval_loss\n",
       "0  2.536659              10   0.155262\n",
       "1  2.113224              10   0.155313\n",
       "2  0.887189              10   0.155314"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "SELECT hyperparameters.l2_reg, hyperparameters.max_tree_depth, eval_loss\n",
    "FROM ML.TRIAL_INFO(MODEL dsongcp.arr_delay_airports_xgh)\n",
    "ORDER BY eval_loss ASC LIMIT 3"
   ]
  },
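  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The query above shows only the three trials with the lowest evaluation loss. To see the whole search, a sketch like the following lists every trial together with its status and the `is_optimal` flag that marks the trial BigQuery ML selected. It assumes the tuning job for `dsongcp.arr_delay_airports_xgh` has finished."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bigquery\n",
    "-- Sketch: one row per tuning trial, in the order the trials were run.\n",
    "SELECT\n",
    "  trial_id,\n",
    "  hyperparameters.l2_reg,\n",
    "  hyperparameters.max_tree_depth,\n",
    "  eval_loss,\n",
    "  status,\n",
    "  is_optimal\n",
    "FROM ML.TRIAL_INFO(MODEL dsongcp.arr_delay_airports_xgh)\n",
    "ORDER BY trial_id"
   ]
  },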
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 3/3 [00:00<00:00, 1079.99query/s]                        \n",
      "Downloading: 100%|██████████| 1/1 [00:02<00:00,  2.70s/rows]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>correct_cancel</th>\n",
       "      <th>total_noncancel</th>\n",
       "      <th>correct_noncancel</th>\n",
       "      <th>total_cancel</th>\n",
       "      <th>rmse</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.841952</td>\n",
       "      <td>1305703</td>\n",
       "      <td>0.965654</td>\n",
       "      <td>283750</td>\n",
       "      <td>0.204358</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   correct_cancel  total_noncancel  correct_noncancel  total_cancel      rmse\n",
       "0        0.841952          1305703           0.965654        283750  0.204358"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "\n",
    "WITH predictions AS (\n",
    "SELECT \n",
    "  *\n",
    "FROM ML.PREDICT(MODEL dsongcp.arr_delay_airports_xgh,\n",
    "                 (\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  origin,\n",
    "  dest,\n",
    "  IF(is_train_day = 'True', False, True) AS is_eval_day\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND \n",
    "  f.DIVERTED = False AND\n",
    "  t.is_train_day = 'False'\n",
    "                 ),\n",
    "                 STRUCT(0.7 AS threshold))),\n",
    "\n",
    "stats AS (\n",
    "SELECT \n",
    "  COUNTIF(ontime != 'ontime' AND ontime = predicted_ontime) AS correct_cancel\n",
    "  , COUNTIF(predicted_ontime = 'ontime') AS total_noncancel\n",
    "  , COUNTIF(ontime = 'ontime' AND ontime = predicted_ontime) AS correct_noncancel\n",
    "  , COUNTIF(ontime != 'ontime') AS total_cancel\n",
    "  , SQRT(SUM((IF(ontime = 'ontime', 1, 0) - p.prob) * (IF(ontime = 'ontime', 1, 0) - p.prob))/COUNT(*)) AS rmse\n",
    "FROM predictions, UNNEST(predicted_ontime_probs) p\n",
    "WHERE p.label = 'ontime'\n",
    ")\n",
    "\n",
    "SELECT\n",
    "   correct_cancel / total_cancel AS correct_cancel\n",
    "   , total_noncancel\n",
    "   , correct_noncancel / total_noncancel AS correct_noncancel\n",
    "   , total_cancel\n",
    "   , rmse\n",
    "FROM stats"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AutoML (optional)\n",
    "\n",
    "Let's try AutoML Tables, which should give us close to state-of-the-art performance.\n",
    "Note, however, that because AutoML does not support a custom data split, we cannot directly\n",
    "compare its performance with that of the other methods.\n",
    "\n",
    "**This will take ~60 minutes. You can skip this step.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bigquery\n",
    "CREATE OR REPLACE MODEL dsongcp.arr_delay_airports_automl\n",
    "OPTIONS(input_label_cols=['ontime'], \n",
    "        model_type='automl_classifier')\n",
    "AS\n",
    "\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  origin,\n",
    "  dest\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND \n",
    "  f.DIVERTED = False AND\n",
    "  is_train_day = 'True'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bigquery\n",
    "\n",
    "WITH predictions AS (\n",
    "SELECT \n",
    "  *\n",
    "FROM ML.PREDICT(MODEL dsongcp.arr_delay_airports_automl,\n",
    "                 (\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  origin,\n",
    "  dest,\n",
    "  IF(is_train_day = 'True', False, True) AS is_eval_day\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND \n",
    "  f.DIVERTED = False AND\n",
    "  t.is_train_day = 'False'\n",
    "                 ),\n",
    "                 STRUCT(0.7 AS threshold))),\n",
    "\n",
    "stats AS (\n",
    "SELECT \n",
    "  COUNTIF(ontime != 'ontime' AND ontime = predicted_ontime) AS correct_cancel\n",
    "  , COUNTIF(predicted_ontime = 'ontime') AS total_noncancel\n",
    "  , COUNTIF(ontime = 'ontime' AND ontime = predicted_ontime) AS correct_noncancel\n",
    "  , COUNTIF(ontime != 'ontime') AS total_cancel\n",
    "  , SQRT(SUM((IF(ontime = 'ontime', 1, 0) - p.prob) * (IF(ontime = 'ontime', 1, 0) - p.prob))/COUNT(*)) AS rmse\n",
    "FROM predictions, UNNEST(predicted_ontime_probs) p\n",
    "WHERE p.label = 'ontime'\n",
    ")\n",
    "\n",
    "SELECT\n",
    "   correct_cancel / total_cancel AS correct_cancel\n",
    "   , total_noncancel\n",
    "   , correct_noncancel / total_noncancel AS correct_noncancel\n",
    "   , total_cancel\n",
    "   , rmse\n",
    "FROM stats"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2021 Google Inc. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
   ]
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "",
   "name": "managed-notebooks.m82",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/managed-notebooks:m82"
  },
  "kernelspec": {
   "display_name": "",
   "name": ""
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
